{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "<a href=\"https://colab.research.google.com/github/milvus-io/bootcamp/blob/master/bootcamp/RAG/advanced_rag/hybrid_and_rerank_with_langchain.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# Hybrid and rerank"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "## Google Colab preparation[optional]\n",
    "This is an optional step, if you want to run this notebook on Google Colab."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "! git clone https://github.com/milvus-io/bootcamp.git"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import shutil\n",
    "src_dir = \"./bootcamp/bootcamp/RAG/advanced_rag/rag_utils\"\n",
    "dst_dir = \"./rag_utils\"\n",
    "shutil.copytree(src_dir, dst_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "! pip install --upgrade langchain langchain-community langchain_milvus langchain-voyageai langchain-openai rank_bm25 bs4"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "Please prepare you [VOYAGE_API_KEY](https://docs.voyageai.com/docs/api-key-and-installation#authentication-with-api-keys) and [OPENAI_API_KEY](https://openai.com/index/openai-api/) in your environment variables.\n",
    "![](imgs/colab_api_key2.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "### If you are running this notebook on Google Colab, you have to restart this session by `Cmd/Ctrl + M`, then press `.` to make the environment take effect."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "from google.colab import userdata\n",
    "import os\n",
    "\n",
    "os.environ['OPENAI_API_KEY'] = userdata.get('OPENAI_API_KEY')\n",
    "os.environ['VOYAGE_API_KEY'] = userdata.get('VOYAGE_API_KEY')"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "----\n",
    "## Get started"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "\n",
    "![](imgs/hybrid_and_rerank.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "## Prepare the data\n",
    "\n",
    "We use the Langchain WebBaseLoader to load documents from [blog sources](https://lilianweng.github.io/posts/2023-06-23-agent/) and split them into chunks using the RecursiveCharacterTextSplitter."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import bs4\n",
    "from langchain_community.document_loaders import WebBaseLoader\n",
    "from langchain_text_splitters import RecursiveCharacterTextSplitter\n",
    "\n",
    "from rag_utils.vanilla import vectorstore\n",
    "\n",
    "# Create a WebBaseLoader instance to load documents from web sources\n",
    "loader = WebBaseLoader(\n",
    "    web_paths=(\"https://lilianweng.github.io/posts/2023-06-23-agent/\",),\n",
    "    bs_kwargs=dict(\n",
    "        parse_only=bs4.SoupStrainer(\n",
    "            class_=(\"post-content\", \"post-title\", \"post-header\")\n",
    "        )\n",
    "    ),\n",
    ")\n",
    "# Load documents from web sources using the loader\n",
    "documents = loader.load()\n",
    "# Initialize a RecursiveCharacterTextSplitter for splitting text into chunks\n",
    "text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n",
    "\n",
    "# Split the documents into chunks using the text_splitter\n",
    "docs = text_splitter.split_documents(documents)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "## Build the chain\n",
    "\n",
    "We load the docs into milvus vectorstore, and build a milvus retriever."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "vectorstore.add_documents(docs)\n",
    "milvus_retriever = vectorstore.as_retriever()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "And build a bm25 retriever from the docs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from langchain_community.retrievers import BM25Retriever\n",
    "\n",
    "bm25_retriever = BM25Retriever.from_documents(docs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "Build a vanilla RAG chain."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from langchain_core.output_parsers import StrOutputParser\n",
    "from langchain_core.runnables import RunnablePassthrough\n",
    "from rag_utils.vanilla import format_docs, rag_prompt, llm\n",
    "\n",
    "\n",
    "vanilla_rag_chain = (\n",
    "    {\"context\": milvus_retriever | format_docs, \"question\": RunnablePassthrough()}\n",
    "    | rag_prompt\n",
    "    | llm\n",
    "    | StrOutputParser()\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "Prepare hybrid_and_rerank_retriever."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from langchain_voyageai import VoyageAIRerank\n",
    "from rag_utils.hybrid_and_rerank import RerankerRunnable\n",
    "\n",
    "reranker = RerankerRunnable(\n",
    "    compressor=VoyageAIRerank(model=\"rerank-lite-1\"),\n",
    "    top_k=4,\n",
    ")\n",
    "\n",
    "hybrid_and_rerank_retriever = {\n",
    "    \"milvus_retrieved_doc\": milvus_retriever,\n",
    "    \"bm25_retrieved_doc\": bm25_retriever,\n",
    "    \"query\": RunnablePassthrough(),\n",
    "} | reranker"
   ]
  },
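  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "The `RerankerRunnable` imported above comes from `rag_utils`, and its code is not shown in this notebook. The following is a minimal sketch of the general idea, assuming it merges the two result lists, drops duplicates, and keeps the `top_k` highest-scoring documents. `merge_unique`, `toy_score`, and `rerank` are hypothetical names, and the word-overlap score is only a stand-in for the VoyageAI reranker.\n",
    "\n",
    "```python\n",
    "from typing import Callable, List\n",
    "\n",
    "\n",
    "def merge_unique(*result_lists: List[str]) -> List[str]:\n",
    "    # Concatenate the result lists, keeping only the first copy of each text.\n",
    "    seen = set()\n",
    "    merged = []\n",
    "    for results in result_lists:\n",
    "        for text in results:\n",
    "            if text not in seen:\n",
    "                seen.add(text)\n",
    "                merged.append(text)\n",
    "    return merged\n",
    "\n",
    "\n",
    "def toy_score(query: str, text: str) -> int:\n",
    "    # Hypothetical stand-in for a reranker: count shared lowercase words.\n",
    "    return len(set(query.lower().split()) & set(text.lower().split()))\n",
    "\n",
    "\n",
    "def rerank(query: str, texts: List[str], top_k: int,\n",
    "           score: Callable[[str, str], int] = toy_score) -> List[str]:\n",
    "    # Sort by descending score (ties keep their original order) and truncate.\n",
    "    return sorted(texts, key=lambda t: score(query, t), reverse=True)[:top_k]\n",
    "\n",
    "\n",
    "vector_hits = ['agents plan tasks', 'memory stores history']\n",
    "keyword_hits = ['models use tools', 'agents plan tasks']\n",
    "candidates = merge_unique(vector_hits, keyword_hits)\n",
    "top = rerank('which models use tools', candidates, top_k=2)\n",
    "```\n",
    "\n",
    "Merging before reranking means that a document returned by both retrievers is scored only once."
   ]
  },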
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "Build hybrid_and_rerank_chain."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from langchain_core.output_parsers import StrOutputParser\n",
    "from langchain_core.runnables import RunnablePassthrough\n",
    "\n",
    "hybrid_and_rerank_chain = (\n",
    "    {\n",
    "        \"context\": hybrid_and_rerank_retriever | format_docs,\n",
    "        \"question\": RunnablePassthrough(),\n",
    "    }\n",
    "    | rag_prompt\n",
    "    | llm\n",
    "    | StrOutputParser()\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "## Test the chain"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "len(milvus_retrieved_doc) = 4\n",
      "len(bm25_retrieved_doc) = 4\n",
      "len(unique_documents) = 6\n",
      "\n",
      "[vanilla_result]:\n",
      "The models that use tools are TALM (Tool Augmented Language Models; Parisi et al. 2022), Toolformer (Schick et al. 2023), and MRKL (Modular Reasoning, Knowledge and Language; Karpas et al. 2022). These models are fine-tuned to learn to use external tool APIs and other external modules for additional capabilities.\n",
      "\n",
      "[hybrid_and_rerank_result]:\n",
      "The models that use tools are ChatGPT Plugins and OpenAI API for function calling, HuggingGPT, TALM (Tool Augmented Language Models), Toolformer, and MRKL (Modular Reasoning, Knowledge and Language). These models are equipped with the capability to use external tools or APIs to extend their functionalities.\n"
     ]
    }
   ],
   "source": [
    "query = \"Which model use tools?\"\n",
    "\n",
    "vanilla_result = vanilla_rag_chain.invoke(query)\n",
    "hybrid_and_rerank_result = hybrid_and_rerank_chain.invoke(query)\n",
    "print(\n",
    "    f\"\\n[vanilla_result]:\\n{vanilla_result}\\n\\n[hybrid_and_rerank_result]:\\n{hybrid_and_rerank_result}\"\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "In [hybrid_and_rerank_result], it answers with more ground truth:\n",
    "> ChatGPT with Plugins and OpenAI API function calling, HuggingGPT\n",
    "\n",
    "which does not appear in [vanilla result]."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# hybrid_and_rerank_retriever.invoke(query)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "![](imgs/hybrid_and_rerank_res1.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# milvus_retriever.invoke(query)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "![](imgs/hybrid_and_rerank_res2.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "We can see that the hybrid retrieved results include `ChatGPT` and `HuggingGPT`. This document contains the word `tool`, which is consistent with the word in the query, that is where BM25 did better.\n",
    "\n",
    "Therefore, we can use bm25 as a supplement to vector retriever to improve the accuracy of recall."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
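  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "The keyword-matching effect described above can be illustrated with a small BM25 scorer. This is a sketch of a common textbook BM25 variant, not necessarily the exact formula used by `rank_bm25` or `BM25Retriever`. The function name `bm25_scores`, the whitespace tokenizer, the parameter values `k1=1.5` and `b=0.75`, and the two sample sentences are all assumptions made for illustration.\n",
    "\n",
    "```python\n",
    "import math\n",
    "from collections import Counter\n",
    "\n",
    "\n",
    "def bm25_scores(query, docs, k1=1.5, b=0.75):\n",
    "    # Tokenize by lowercasing and splitting on whitespace (a simplification).\n",
    "    tokenized = [d.lower().split() for d in docs]\n",
    "    n_docs = len(tokenized)\n",
    "    avg_len = sum(len(t) for t in tokenized) / n_docs\n",
    "    scores = []\n",
    "    for tokens in tokenized:\n",
    "        counts = Counter(tokens)\n",
    "        score = 0.0\n",
    "        for term in query.lower().split():\n",
    "            # Document frequency: how many docs contain the term.\n",
    "            df = sum(1 for t in tokenized if term in t)\n",
    "            idf = math.log((n_docs - df + 0.5) / (df + 0.5) + 1)\n",
    "            tf = counts[term]\n",
    "            # Length normalization penalizes longer-than-average docs.\n",
    "            norm = tf + k1 * (1 - b + b * len(tokens) / avg_len)\n",
    "            score += idf * tf * (k1 + 1) / norm\n",
    "        scores.append(score)\n",
    "    return scores\n",
    "\n",
    "\n",
    "sample_docs = [\n",
    "    'HuggingGPT selects a tool for each task',\n",
    "    'Long-term memory is kept in a vector store',\n",
    "]\n",
    "scores = bm25_scores('which tool', sample_docs)\n",
    "```\n",
    "\n",
    "Only the first sample sentence shares a term with the query, so it is the only one with a non-zero score. An embedding-based retriever has no such exact-match guarantee, which is why the two approaches complement each other."
   ]
  },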
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}